package com.hys.app.framework.ws.impl;

import com.hys.app.framework.context.app.AppTypeEnum;
import com.hys.app.framework.ws.SessionManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.springframework.web.socket.WebSocketSession;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Registry of live {@link WebSocketSession}s, kept in one map per client application type.
 * <p>
 * Sessions are keyed by their {@code userId} session attribute (a positive {@link Long}).
 * They are partitioned by their {@code appType} session attribute, which holds the name of an
 * {@link AppTypeEnum} constant. At most one session per user is tracked for each app type.
 * Adding a new session for the same user replaces the previous entry.
 * <p>
 * The backing maps are {@code static} {@link ConcurrentHashMap}s. The registry is therefore
 * shared by every instance of this class, and each individual map operation is thread-safe.
 *
 * @author kingapex
 * @version 1.0
 * @since 2022/7/28 17:46
 **/
@Service
public class SessionManagerImpl implements SessionManager {

    private static final Logger logger = LoggerFactory.getLogger(SessionManagerImpl.class);

    /** Session attribute holding the name of an {@link AppTypeEnum} constant. */
    private static final String APP_TYPE_ATTRIBUTE = "appType";

    /** Session attribute holding the user id as a {@link Long}. */
    private static final String USER_ID_ATTRIBUTE = "userId";

    private static final Map<Long, WebSocketSession> ADMIN_SESSION_MAP = new ConcurrentHashMap<>();

    private static final Map<Long, WebSocketSession> SELLER_SESSION_MAP = new ConcurrentHashMap<>();

    private static final Map<Long, WebSocketSession> BUYER_SESSION_MAP = new ConcurrentHashMap<>();

    /**
     * Resolves the session map for the app type stored in the session's {@code appType} attribute.
     *
     * @param session the websocket session whose attributes are inspected
     * @return the matching map, or {@code null} when the attribute is missing, is not a
     *         {@code String}, or does not name an {@link AppTypeEnum} constant
     */
    private Map<Long, WebSocketSession> getSessionMap(WebSocketSession session) {
        Object appType = session.getAttributes().get(APP_TYPE_ATTRIBUTE);

        // Guard before valueOf: a null or non-String attribute would otherwise throw NPE/CCE.
        if (!(appType instanceof String)) {
            logger.error("错误的app type{}", appType);
            return null;
        }

        AppTypeEnum appTypeEnum;
        try {
            appTypeEnum = AppTypeEnum.valueOf((String) appType);
        } catch (IllegalArgumentException e) {
            // Unknown constant name: report it and let callers reject the session instead of crashing.
            logger.error("错误的app type{}", appType, e);
            return null;
        }

        return getSessionMap(appTypeEnum);
    }

    /**
     * Maps an app type to its backing session map.
     *
     * @param appType the application type; may be {@code null}
     * @return the session map for {@code appType}, or {@code null} if it is {@code null}
     *         or has no associated map
     */
    private Map<Long, WebSocketSession> getSessionMap(AppTypeEnum appType) {
        if (appType == null) {
            logger.error("错误的app type{}", appType);
            return null;
        }

        switch (appType) {
            case Admin:
                return ADMIN_SESSION_MAP;
            case Shop:
                return SELLER_SESSION_MAP;
            case Buyer:
                return BUYER_SESSION_MAP;
            default:
                return null;
        }
    }

    /**
     * Reads and validates the {@code userId} session attribute.
     *
     * @param session the websocket session whose attributes are inspected
     * @return the user id, or {@code null} (after logging) when it is missing, not a
     *         {@link Long}, or not positive
     */
    private Long getUserId(WebSocketSession session) {
        Object userId = session.getAttributes().get(USER_ID_ATTRIBUTE);
        if (!(userId instanceof Long) || (Long) userId <= 0) {
            logger.error("错误的参数： userId{} ", userId);
            return null;
        }
        return (Long) userId;
    }

    /**
     * Registers {@code session} under its user id and app type.
     * Any session previously registered for the same user and app type is replaced.
     *
     * @param session the session to register
     * @return {@code true} if registered; {@code false} if the {@code userId} or
     *         {@code appType} attribute is invalid
     */
    @Override
    public boolean addSession(WebSocketSession session) {
        Long userId = getUserId(session);
        if (userId == null) {
            return false;
        }

        Map<Long, WebSocketSession> sessionMap = getSessionMap(session);
        if (sessionMap == null) {
            return false;
        }

        sessionMap.put(userId, session);
        return true;
    }

    /**
     * Unregisters {@code session}, but only if it is still the session registered for its user.
     *
     * @param session the session to unregister
     * @return {@code true} if the attributes are valid (whether or not an entry was removed);
     *         {@code false} if the {@code userId} or {@code appType} attribute is invalid
     */
    @Override
    public boolean removeSession(WebSocketSession session) {
        Long userId = getUserId(session);
        if (userId == null) {
            return false;
        }

        Map<Long, WebSocketSession> sessionMap = getSessionMap(session);
        if (sessionMap == null) {
            return false;
        }

        // Conditional remove. If the user has reconnected, the map already holds the newer
        // session, and closing the stale one must not evict it.
        sessionMap.remove(userId, session);
        return true;
    }

    /**
     * Looks up the session registered for a user under the given app type.
     *
     * @param appType the application type; may be {@code null}
     * @param userId  the user id; may be {@code null}
     * @return the registered session, or {@code null} if there is none or an argument is invalid
     */
    @Override
    public WebSocketSession getSession(AppTypeEnum appType, Long userId) {
        // ConcurrentHashMap.get(null) throws NPE, so reject a null key up front.
        if (userId == null) {
            return null;
        }
        Map<Long, WebSocketSession> sessionMap = this.getSessionMap(appType);
        return sessionMap == null ? null : sessionMap.get(userId);
    }
}
